Replace AVG_POOL2D with REDUCE_SUM in DecomposeMeanDimPass (#19242)#19242
Replace AVG_POOL2D with REDUCE_SUM in DecomposeMeanDimPass (#19242)#19242meta-codesync[bot] merged 1 commit intomainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19242
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 1 Pending, 4 Unrelated FailuresAs of commit 68c038b with merge base a3dd0fa ( NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@mcremon-meta has exported this pull request. If you are a Meta employee, you can view the originating Diff in D101418199. |
This PR needs a
|
|
Thanks for the PR, Matthias! I see a few failing tests, could you take a look into them? |
Summary: Replace the avg_pool2d decomposition path in DecomposeMeanDimPass with REDUCE_SUM + MUL(1/N) for all mean.dim reductions. AVG_POOL2D can only pool over spatial (H×W) axes in TOSA/NHWC layout, which forces the compiler to insert TRANSPOSE ops when the reduction is over channels (common in LayerNorm). REDUCE_SUM works on any axis without layout constraints, avoiding those transposes entirely. Reviewed By: 3l1 Differential Revision: D101418199
8dff93f to
b3fffa9
Compare
test_layer_norm.py: Removed randn_last_three_dims and randn_last_three_dims_no_bias from U85 16a8w xfails — the new sum-based decomposition improved accuracy enough that these tests now pass. test_transpose_counts.py: Updated 5 expected TRANSPOSE counts (groupnorm 1→0, groupnorm_channels_last 3→2, model_2 11→9, model_4 5→3, model_5 6→4) — replacing avg_pool2d with sum+mul eliminates NHWC layout conversions. test_cond.py: Xfailed one_arg_two_outputs in both TOSA INT and U85 INT — the new decomposition creates a full constant inside torch.cond branches which PyTorch's constant folder freezes into a parameter the branch submodule can't access. @gggekov please note the last one. Seems like there is an issue with higher order ops and constant folding. I do think it's out of scope for this, so flagging it for later fix! Edit: changed the op from |
3l1
left a comment
There was a problem hiding this comment.
let's review why we fail here before merging
tosa_int_xfails = {
"one_arg_two_outputs": "mean decomposition creates frozen constant inside cond branch that breaks re-export",
}
Summary: Replace the avg_pool2d decomposition path in DecomposeMeanDimPass with REDUCE_SUM + MUL(1/N) for all mean.dim reductions. AVG_POOL2D can only pool over spatial (H×W) axes in TOSA/NHWC layout, which forces the compiler to insert TRANSPOSE ops when the reduction is over channels (common in LayerNorm). REDUCE_SUM works on any axis without layout constraints, avoiding those transposes entirely. Reviewed By: 3l1 Differential Revision: D101418199
b3fffa9 to
c755445
Compare
Summary: Replace the avg_pool2d decomposition path in DecomposeMeanDimPass with REDUCE_SUM + MUL(1/N) for all mean.dim reductions. AVG_POOL2D can only pool over spatial (H×W) axes in TOSA/NHWC layout, which forces the compiler to insert TRANSPOSE ops when the reduction is over channels (common in LayerNorm). REDUCE_SUM works on any axis without layout constraints, avoiding those transposes entirely. Reviewed By: 3l1 Differential Revision: D101418199
c755445 to
68c038b
Compare
Copied from above comment: changed the op from mean to sum in the test_cond file, since it's not the point of the test to care about that op. No xfail added in the end |
Summary:
Replace the avg_pool2d decomposition path in DecomposeMeanDimPass with
REDUCE_SUM + MUL(1/N) for all mean.dim reductions.
AVG_POOL2D can only pool over spatial (H×W) axes in TOSA/NHWC layout,
which forces the compiler to insert TRANSPOSE ops when the reduction is
over channels (common in LayerNorm). REDUCE_SUM works on any axis without
layout constraints, avoiding those transposes entirely.
Reviewed By: 3l1
Differential Revision: D101418199